{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train and export TensorFlow Serving model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Spark application\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<tr><th>ID</th><th>YARN Application ID</th><th>Kind</th><th>State</th><th>Spark UI</th><th>Driver log</th><th>Current session?</th></tr><tr><td>0</td><td>application_1560176614871_0001</td><td>pyspark</td><td>idle</td><td><a target=\"_blank\" href=\"http://hopsworks0.logicalclocks.com:8088/proxy/application_1560176614871_0001/\">Link</a></td><td><a target=\"_blank\" href=\"http://hopsworks0.logicalclocks.com:8042/node/containerlogs/container_e01_1560176614871_0001_01_000001/demo_deep_learning_admin000__meb10000\">Link</a></td><td>✔</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SparkSession available as 'spark'.\n"
     ]
    }
   ],
   "source": [
    "def wrapper():\n",
    "\n",
    "    import os\n",
    "    import sys\n",
    "\n",
    "    # This is a placeholder for a Google-internal import.\n",
    "\n",
    "    import tensorflow as tf\n",
    "\n",
    "    #import mnist_input_data\n",
    "    from tensorflow.examples.tutorials.mnist import input_data as mnist_input_data\n",
    "\n",
    "    from hops import serving\n",
    "\n",
    "    training_iteration=1000\n",
    "    model_version=1\n",
    "    work_dir=os.getcwd()\n",
    "\n",
    "    # Train model\n",
    "    print('Training model...')\n",
    "    mnist = mnist_input_data.read_data_sets(work_dir, one_hot=True)\n",
    "    sess = tf.InteractiveSession()\n",
    "    serialized_tf_example = tf.placeholder(tf.string, name='tf_example')\n",
    "    feature_configs = {'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32),}\n",
    "    tf_example = tf.parse_example(serialized_tf_example, feature_configs)\n",
    "    x = tf.identity(tf_example['x'], name='x')  # use tf.identity() to assign name\n",
    "    y_ = tf.placeholder('float', shape=[None, 10])\n",
    "    w = tf.Variable(tf.zeros([784, 10]))\n",
    "    b = tf.Variable(tf.zeros([10]))\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    y = tf.nn.softmax(tf.matmul(x, w) + b, name='y')\n",
    "    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))\n",
    "    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)\n",
    "    values, indices = tf.nn.top_k(y, 10)\n",
    "    table = tf.contrib.lookup.index_to_string_table_from_tensor(\n",
    "        tf.constant([str(i) for i in range(10)]))\n",
    "    prediction_classes = table.lookup(tf.to_int64(indices))\n",
    "    for _ in range(training_iteration):\n",
    "        batch = mnist.train.next_batch(50)\n",
    "        train_step.run(feed_dict={x: batch[0], y_: batch[1]})\n",
    "    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float')) \n",
    "    acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})\n",
    "    print('Training Accuracy: {}'.format(acc))\n",
    "    print('Done training!')\n",
    "\n",
    "    # Export model\n",
    "    # WARNING(break-tutorial-inline-code): The following code snippet is\n",
    "    # in-lined in tutorials, please update tutorial documents accordingly\n",
    "    # whenever code changes.\n",
    "\n",
    "    export_path = work_dir + '/model'\n",
    "    print('Exporting trained model to: {}'.format(export_path))\n",
    "    builder = tf.saved_model.builder.SavedModelBuilder(export_path)\n",
    "\n",
    "    # Build the signature_def_map.\n",
    "    classification_inputs = tf.saved_model.utils.build_tensor_info(\n",
    "        serialized_tf_example)\n",
    "    classification_outputs_classes = tf.saved_model.utils.build_tensor_info(\n",
    "        prediction_classes)\n",
    "    classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values)\n",
    "\n",
    "    classification_signature = (\n",
    "        tf.saved_model.signature_def_utils.build_signature_def(\n",
    "              inputs={\n",
    "                  tf.saved_model.signature_constants.CLASSIFY_INPUTS:\n",
    "                      classification_inputs\n",
    "              },\n",
    "              outputs={\n",
    "                  tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:\n",
    "                      classification_outputs_classes,\n",
    "                  tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:\n",
    "                      classification_outputs_scores\n",
    "              },\n",
    "              method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))\n",
    "\n",
    "    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)\n",
    "    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)\n",
    "\n",
    "    prediction_signature = (\n",
    "          tf.saved_model.signature_def_utils.build_signature_def(\n",
    "              inputs={'images': tensor_info_x},\n",
    "              outputs={'scores': tensor_info_y},\n",
    "              method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))\n",
    "\n",
    "    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')\n",
    "    builder.add_meta_graph_and_variables(\n",
    "          sess, [tf.saved_model.tag_constants.SERVING],\n",
    "          signature_def_map={\n",
    "              'predict_images':\n",
    "                  prediction_signature,\n",
    "              tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:\n",
    "                  classification_signature,\n",
    "          },\n",
    "          legacy_init_op=legacy_init_op)\n",
    "\n",
    "    builder.save()\n",
    "\n",
    "    print('Done exporting!')\n",
    "\n",
    "    serving.export(export_path, \"mnist\", 1, overwrite=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished Experiment \n",
      "\n",
      "'hdfs://10.0.2.15:8020/Projects/demo_deep_learning_admin000/Experiments/application_1560176614871_0001/launcher/run.1'"
     ]
    }
   ],
   "source": [
    "from hops import experiment\n",
    "\n",
    "experiment.launch(wrapper, name='mnist serving example')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Model Serving of Exported Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hops import serving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating a serving for model mnist ...\n",
      "Serving for model mnist successfully created"
     ]
    }
   ],
   "source": [
    "# Create serving\n",
    "model_path=\"/Models/mnist/\"\n",
    "SERVING_NAME=\"mnist\"\n",
    "SERVING_VERSION=1\n",
    "response = serving.create_or_update(model_path, SERVING_NAME, serving_type=\"TENSORFLOW\", \n",
    "                                 model_version=SERVING_VERSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist"
     ]
    }
   ],
   "source": [
    "# List all available servings in the project\n",
    "for s in serving.get_all():\n",
    "    print(s.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'Stopped'"
     ]
    }
   ],
   "source": [
    "# Get serving status\n",
    "serving.get_status(SERVING_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start Model Serving Server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting serving with name: mnist...\n",
      "Serving with name: mnist successfully started"
     ]
    }
   ],
   "source": [
    "if serving.get_status(SERVING_NAME) == 'Stopped':\n",
    "    serving.start(SERVING_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "time.sleep(10) # Let the serving startup correctly"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Send Prediction Requests to the Served Model using Hopsworks REST API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "TOPIC_NAME = serving.get_kafka_topic(SERVING_NAME)\n",
    "NUM_FEATURES=784"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [[0.0014034, 1.374e-08, 0.200702, 0.278433, 2.62939e-06, 0.515982, 0.000110427, 0.000536466, 0.00279624, 3.41006e-05]]}\n",
      "{'predictions': [[0.00105801, 1.44569e-08, 0.370394, 0.293638, 1.62416e-06, 0.32803, 6.99966e-05, 0.00376709, 0.00293636, 0.000104489]]}\n",
      "{'predictions': [[0.000664825, 4.21654e-08, 0.933665, 0.0423333, 1.82413e-07, 0.020572, 1.23976e-05, 6.47347e-05, 0.0026592, 2.8205e-05]]}\n",
      "{'predictions': [[0.000324636, 2.48305e-08, 0.707619, 0.117427, 7.54794e-06, 0.169898, 2.80385e-05, 8.61907e-05, 0.00433046, 0.000278655]]}\n",
      "{'predictions': [[0.00161419, 1.70108e-08, 0.278404, 0.166106, 3.7164e-06, 0.546742, 0.000127588, 0.000609636, 0.00624632, 0.000147593]]}\n",
      "{'predictions': [[1.95921e-05, 1.39617e-09, 0.938718, 0.0441663, 5.77402e-08, 0.0165412, 1.19215e-05, 3.79244e-06, 0.000537965, 1.43126e-06]]}\n",
      "{'predictions': [[0.00211452, 1.34315e-07, 0.669427, 0.245902, 7.54182e-07, 0.0784495, 0.000155536, 0.00103769, 0.00264703, 0.000265703]]}\n",
      "{'predictions': [[0.00575027, 1.63012e-08, 0.394119, 0.0493724, 2.69317e-07, 0.547179, 0.000184195, 1.54667e-05, 0.00337643, 2.77441e-06]]}\n",
      "{'predictions': [[0.000256665, 6.71677e-09, 0.850342, 0.0763357, 5.00716e-08, 0.0697405, 1.60117e-05, 1.73588e-05, 0.00329028, 1.77844e-06]]}\n",
      "{'predictions': [[0.000110681, 3.03542e-08, 0.255851, 0.235413, 1.1476e-06, 0.507281, 1.76124e-05, 6.91142e-05, 0.00125116, 5.70562e-06]]}\n",
      "{'predictions': [[6.94022e-05, 1.47383e-08, 0.0675792, 0.0302665, 2.20558e-07, 0.900342, 0.00010195, 5.27122e-06, 0.00163475, 2.87142e-07]]}\n",
      "{'predictions': [[0.00117605, 6.95718e-08, 0.455632, 0.266518, 6.63793e-08, 0.27124, 0.000100641, 3.82984e-05, 0.00527822, 1.65679e-05]]}\n",
      "{'predictions': [[0.00190411, 1.39781e-08, 0.322273, 0.569318, 1.88067e-06, 0.094147, 0.00025047, 0.000604175, 0.0114186, 8.25445e-05]]}\n",
      "{'predictions': [[0.000365211, 5.28832e-08, 0.134358, 0.0940924, 2.14606e-07, 0.768545, 6.7301e-05, 3.75368e-05, 0.00253137, 3.29204e-06]]}\n",
      "{'predictions': [[0.000116873, 8.53571e-09, 0.531279, 0.215315, 7.3854e-07, 0.252219, 1.7134e-05, 7.63555e-05, 0.000942414, 3.35974e-05]]}\n",
      "{'predictions': [[0.000397057, 2.3538e-09, 0.196612, 0.0414479, 9.77994e-08, 0.76035, 1.97745e-05, 5.06133e-06, 0.00116435, 3.3667e-06]]}\n",
      "{'predictions': [[0.00133358, 1.45237e-08, 0.310965, 0.106062, 1.29092e-07, 0.57874, 0.000125932, 2.52944e-05, 0.00273446, 1.30079e-05]]}\n",
      "{'predictions': [[5.12985e-05, 3.20011e-08, 0.370492, 0.128075, 1.05856e-05, 0.499149, 0.000176083, 3.39121e-05, 0.00200001, 1.2389e-05]]}\n",
      "{'predictions': [[0.000550011, 3.89663e-08, 0.112325, 0.0339454, 8.92683e-06, 0.841494, 0.000244896, 9.87889e-05, 0.0112907, 4.23188e-05]]}\n",
      "{'predictions': [[0.00230525, 4.8757e-08, 0.707382, 0.183992, 2.81007e-07, 0.0941342, 0.000303966, 6.94327e-05, 0.0117939, 1.8962e-05]]}"
     ]
    }
   ],
   "source": [
    "for i in range(20):\n",
    "    data = {\n",
    "                \"signature_name\": 'predict_images',\n",
    "                \"instances\": [np.random.rand(784).tolist()]\n",
    "            }\n",
    "    response = serving.make_inference_request(SERVING_NAME, data)\n",
    "    print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PySpark",
   "language": "",
   "name": "pysparkkernel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "python",
    "version": 2
   },
   "mimetype": "text/x-python",
   "name": "pyspark",
   "pygments_lexer": "python2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}